{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvInv8aSxy3A"
      },
      "source": [
        "## ViT-Plex Demo\n",
        "\n",
        "*Licensed under the Apache License, Version 2.0.*\n",
        "\n",
        "To run this in a public Colab, change the GitHub link: replace github.com with [githubtocolab.com](http://githubtocolab.com)\n",
        "\n",
        "\u003ca href=\"https://githubtocolab.com/google/uncertainty-baselines/blob/main/experimental/plex/plex_vit_demo.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
        "\n",
        "This notebook demonstrates how one can utilize the released **ViT-Plex** checkpoints from the *Plex: Towards Reliability using Pretrained Large Model Extensions* paper using [JAX](https://jax.readthedocs.io/). The **General usage** section provides a minimal setup for loading the checkpoints and making predictions, and the ***Uncertainty***, ***Robust Generalization***, and ***Adaptation*** sections delve deeper into the three areas of reliability for which Plex is designed to excel.\n",
        "\n",
        "For more advanced usage, full training and fine-tuning scripts can be found at https://github.com/google/uncertainty-baselines/tree/main/baselines/jft."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7HBkyhA_yYBr"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "32ZPMFa7N_3u"
      },
      "outputs": [],
      "source": [
        "# NOTE: Use `tpu-colab` when running on a hosted TPU Colab runtime. Use `tpu`\n",
        "# when running on a GCP TPU machine.\n",
        "backend = \"cpu\"  #@param [\"tpu-colab\", \"tpu\", \"gpu\", \"cpu\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TDl8fDWONRip"
      },
      "outputs": [],
      "source": [
        "pip_install = True\n",
        "if pip_install:\n",
        "  # NOTE: Set the jax version to \u003e=0.3.14 if Python 3.9+ is available.\n",
        "  if backend == \"cpu\" or backend == \"tpu-colab\":\n",
        "    !python3 -m pip install \"jax~=0.2.27\"\n",
        "  elif backend == \"tpu\":\n",
        "    !python3 -m pip install \"jax[tpu]~=0.2.27\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\n",
        "  elif backend == \"gpu\":\n",
        "    !python3 -m pip install \"jax[cuda]~=0.2.27\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
        "  else:\n",
        "    raise ValueError(\"Backend must be one of ['cpu', 'tpu', 'gpu']. got \"\n",
        "                     f\"backend={backend} instead.\")\n",
        "  !rm -rf uncertainty-baselines\n",
        "  !git clone https://github.com/google/uncertainty-baselines.git\n",
        "  !cp -r uncertainty-baselines/baselines/jft/* .\n",
        "  # NOTE: Remove the explicit tensorflow-federated and tensorflow_probability\n",
        "  # installs if Python 3.9+ is available.\n",
        "  !python3 -m pip install \"tensorflow-federated==0.20.0\" \"tensorflow_probability\u003c0.17.0\" ./uncertainty-baselines[tensorflow,jax,models,datasets]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0yEmRO5rQ338"
      },
      "outputs": [],
      "source": [
        "if backend == \"tpu-colab\":\n",
        "  import jax.tools.colab_tpu\n",
        "  jax.tools.colab_tpu.setup_tpu()\n",
        "\n",
        "import functools\n",
        "\n",
        "from clu import preprocess_spec\n",
        "import flax\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "import ml_collections\n",
        "import sklearn\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import uncertainty_baselines as ub\n",
        "import checkpoint_utils  # local file import from baselines.jft\n",
        "import input_utils  # local file import from baselines.jft\n",
        "import ood_utils  # local file import from baselines.jft\n",
        "import preprocess_utils  # local file import from baselines.jft"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sfgunmx79Cc6"
      },
      "outputs": [],
      "source": [
        "# If running with TPUs, the following should output a list of TPU devices.\n",
        "print(jax.local_devices())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hBa4sr7EWZp0"
      },
      "outputs": [],
      "source": [
        "# Set a base seed to use for the notebook.\n",
        "rng = jax.random.PRNGKey(42)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GPlHLfmlyao0"
      },
      "source": [
        "## General usage"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zQQcvGcPRnvE"
      },
      "source": [
        "### Load model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qKq6Pvv7S-fs"
      },
      "outputs": [],
      "source": [
        "def get_finetuned_config():\n",
        "  # From `https://github.com/google/uncertainty-baselines/blob/main/baselines/jft/experiments/vit_l32_plex_finetune.py`.\n",
        "  # TODO(dusenberrymw): Clean up this config.\n",
        "  config = ml_collections.ConfigDict()\n",
        "  config.model = ml_collections.ConfigDict()\n",
        "  config.model.patches = ml_collections.ConfigDict()\n",
        "  config.model.patches.size = [32, 32]\n",
        "  config.model.hidden_size = 1024\n",
        "  config.model.transformer = ml_collections.ConfigDict()\n",
        "  config.model.transformer.mlp_dim = 4096\n",
        "  config.model.transformer.num_heads = 16\n",
        "  config.model.transformer.num_layers = 24\n",
        "  config.model.transformer.attention_dropout_rate = 0.\n",
        "  config.model.transformer.dropout_rate = 0.\n",
        "  config.model.classifier = 'token'\n",
        "  config.model.representation_size = None\n",
        "\n",
        "  # Heteroscedastic\n",
        "  config.model.multiclass = True\n",
        "  config.model.temperature = 1.25\n",
        "  config.model.mc_samples = 1000\n",
        "  config.model.num_factors = 15\n",
        "  config.model.param_efficient = False\n",
        "\n",
        "  # BatchEnsemble\n",
        "  config.model.transformer.be_layers = (21, 22, 23)\n",
        "  config.model.transformer.ens_size = 3\n",
        "  config.model.transformer.random_sign_init = -0.5\n",
        "  config.model.transformer.ensemble_attention = False\n",
        "\n",
        "  # TODO(dusenberrymw): Remove the need to include this GP config.\n",
        "  # GP\n",
        "  config.model.use_gp = False\n",
        "  config.model.covmat_momentum = .999\n",
        "  config.model.ridge_penalty = 1.\n",
        "  config.model.mean_field_factor = -1.\n",
        "  return config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Yzbze0eRi9j"
      },
      "outputs": [],
      "source": [
        "num_classes = 1000\n",
        "config = get_finetuned_config()\n",
        "model = ub.models.vision_transformer_het_gp_be(\n",
        "    num_classes=num_classes, **config.model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zQjhZNpyYs9y"
      },
      "outputs": [],
      "source": [
        "@jax.jit\n",
        "def predict_fn(params, images, rng):\n",
        "  rng_dropout, rng_diag_noise, rng_standard_noise = jax.random.split(rng, num=3)\n",
        "  tiled_logits, _ = model.apply(\n",
        "      {'params': flax.core.freeze(params)},\n",
        "      images,\n",
        "      train=False,\n",
        "      rngs={\n",
        "          'dropout': rng_dropout,\n",
        "          'diag_noise_samples': rng_diag_noise,\n",
        "          'standard_norm_noise_samples': rng_standard_noise})\n",
        "  ens_logits = jnp.stack(jnp.split(tiled_logits, model.transformer.ens_size))\n",
        "  ens_probs = jax.nn.softmax(ens_logits)\n",
        "  avg_probs = jnp.mean(ens_probs, axis=0)  # Average over ensemble members.\n",
        "  return avg_probs  # Shape (batch_size, num_classes)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kugZWCfbW2lz"
      },
      "outputs": [],
      "source": [
        "checkpoint_path = \"gs://plex-paper/plex_vit_large_imagenet21k_to_imagenet.npz\"\n",
        "read_in_parallel = False\n",
        "checkpoint = checkpoint_utils.load_checkpoint(None, path=checkpoint_path,\n",
        "                                              read_in_parallel=read_in_parallel)\n",
        "params = checkpoint[\"opt\"][\"target\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nw0Gu5FCRsdI"
      },
      "source": [
        "### Make predictions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fAjGFhvck3gI"
      },
      "source": [
        "#### Single image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qIHbvoVjkhSk"
      },
      "outputs": [],
      "source": [
        "# Get a single image from https://www.tensorflow.org/datasets/catalog/imagenet_v2.\n",
        "# Direct URL: https://knowyourdata-tfds.withgoogle.com/#dataset=imagenet_v2\u0026tab=ITEM\u0026select=kyd%2Fimagenet_v2%2Flabel\u0026item=205%2F194ab2af3f5802ad12e1f4327d598743b01489c0.jpeg\n",
        "!wget --no-check-certificate \"https://knowyourdata-tfds.withgoogle.com/serve_image?\u0026id=205%2F194ab2af3f5802ad12e1f4327d598743b01489c0.jpeg\u0026segment_name=default_segment\u0026pipeline_id=20220622-e9a437\u0026cachebust=undsqlLh\u0026dataset=imagenet_v2\" -O image.jpg\n",
        "from IPython.display import Image, display\n",
        "display(Image(\"image.jpg\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AyPrGPoaaFbT"
      },
      "outputs": [],
      "source": [
        "# Load and preprocess image.\n",
        "def preprocess_fn(image):\n",
        "  # Note: The model was trained with this preprocessing.\n",
        "  x = tf.convert_to_tensor(image)\n",
        "  x = tf.io.decode_image(x, channels=3, expand_animations=False)\n",
        "  x = tf.image.resize(x, (384, 384))\n",
        "  x = tf.cast(x, tf.float32) / 255. * 2 - 1\n",
        "  return jnp.asarray(x)\n",
        "\n",
        "with open(\"image.jpg\", mode='rb') as f:\n",
        "  image_data = f.read()\n",
        "\n",
        "image = preprocess_fn(image_data)\n",
        "image.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FigL0s4BuwpY"
      },
      "outputs": [],
      "source": [
        "# Make a prediction.\n",
        "rng_eval = jax.random.fold_in(rng, 0)\n",
        "images = jnp.array([image])  # Create a batch of 1 image.\n",
        "probs = predict_fn(params, images, rng_eval)\n",
        "probs.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W5rOYy7eaKZv"
      },
      "outputs": [],
      "source": [
        "# Output top 5 predictions.\n",
        "all_top_preds = tf.keras.applications.imagenet_utils.decode_predictions(\n",
        "    probs, top=5)\n",
        "\n",
        "for top_preds in all_top_preds:\n",
        "  for _, pred_class_name, prob in top_preds:\n",
        "    print(f\"{float(prob):.6f} : {pred_class_name}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HvXDo8GDfKsD"
      },
      "source": [
        "#### Batch of images w/ multiple devices\n",
        "\n",
        "Here we demonstrate how to make predictions with the model on a batch of images using multiple devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uYPSMpNdXHfY"
      },
      "outputs": [],
      "source": [
        "def load_val_ds(dataset, split, batch_size, preprocess_eval_fn):\n",
        "  # NOTE: The data loader yields examples of shape\n",
        "  # (num_devices, batch_size/num_devices, ...), i.e., it splits the batch_size\n",
        "  # across the number of local devices, under the assumption that TPUs or\n",
        "  # multiple GPUs are used.\n",
        "  val_ds = input_utils.get_data(\n",
        "      dataset=dataset,\n",
        "      split=split,\n",
        "      rng=None,\n",
        "      process_batch_size=batch_size,\n",
        "      preprocess_fn=preprocess_eval_fn,\n",
        "      cache=False,\n",
        "      num_epochs=1,\n",
        "      repeat_after_batching=True,\n",
        "      shuffle=False,\n",
        "      prefetch_size=0,\n",
        "      drop_remainder=False,\n",
        "      data_dir=None)\n",
        "  return val_ds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ayMArKrKRX9n"
      },
      "outputs": [],
      "source": [
        "pp_eval = \"decode|resize(384)|value_range(-1, 1)|onehot(1000, key='label', key_result='labels')|keep(['image', 'labels'])\"\n",
        "preprocess_eval_fn = preprocess_spec.parse(\n",
        "    spec=pp_eval, available_ops=preprocess_utils.all_ops())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pPDwLwWxZ5Yz"
      },
      "outputs": [],
      "source": [
        "# https://www.tensorflow.org/datasets/catalog/imagenet_v2\n",
        "dataset = \"imagenet_v2\"\n",
        "tfds.builder(dataset).download_and_prepare()\n",
        "split = \"test\"\n",
        "batch_size = 64 * jax.local_device_count()\n",
        "val_ds = load_val_ds(dataset, split=split, batch_size=batch_size,\n",
        "                     preprocess_eval_fn=preprocess_eval_fn)\n",
        "val_ds.element_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tVLtDmnma_SW"
      },
      "outputs": [],
      "source": [
        "# Create a model function that works across multiple TPU devices or across\n",
        "# multiple GPUs for performance. The value for `in_axes` means that the `params`\n",
        "# argument for `predict_fn` will be copied to each device, the `images` will be\n",
        "# split (\"sharded\") across the devices along the first axis, and the `rng` will\n",
        "# be copied to each device. Note that this means that `images` should have shape\n",
        "# `(num_devices, batch_size/num_devices, h, w, c)` so that each device processes\n",
        "# a `(batch_size/num_devices, h, w, c)` chunk of the images. The `params` and\n",
        "# `rng` will be the same as in the \"Singe image\" example up above.\n",
        "pmapped_predict_fn = jax.pmap(predict_fn, in_axes=(None, 0, None))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IVS8wRpMal_J"
      },
      "outputs": [],
      "source": [
        "batch = next(val_ds.as_numpy_iterator())\n",
        "rng_eval = jax.random.fold_in(rng, 0)\n",
        "probs = pmapped_predict_fn(params, batch[\"image\"], rng_eval)\n",
        "# Note that probs is of shape (num_devices, batch_size, num_classes).\n",
        "probs.shape, probs.device_buffers[0].device()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-TeTE5GSNN3C"
      },
      "outputs": [],
      "source": [
        "def get_and_reshape(x):\n",
        "  # Fetch probs from all devices to CPU and reshape to (batch_size, ...).\n",
        "  return jnp.reshape(jax.device_get(x), (-1,) + x.shape[2:])\n",
        "\n",
        "images = get_and_reshape(batch[\"image\"])\n",
        "all_top_preds = tf.keras.applications.imagenet_utils.decode_predictions(\n",
        "    get_and_reshape(probs), top=5)\n",
        "labels = tf.keras.applications.imagenet_utils.decode_predictions(\n",
        "    get_and_reshape(batch[\"labels\"]), top=1)\n",
        "\n",
        "# Only show 10 images.\n",
        "for _, image, top_preds, label in zip(range(10), images, all_top_preds, labels):\n",
        "  plt.figure(figsize=(4, 4))\n",
        "  plt.imshow(image * .5 + .5)\n",
        "  plt.axis('off')\n",
        "  plt.show()\n",
        "\n",
        "  correct_class_name = label[0][1]\n",
        "  for _, pred_class_name, prob in top_preds:\n",
        "    print(f\"{float(prob):.6f} : {pred_class_name}\")\n",
        "  print(f\"Correct class: {correct_class_name}\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S3VjVtYkNlHp"
      },
      "source": [
        "## Reliability"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q9RwvXOaxsgA"
      },
      "source": [
        "### Uncertainty\n",
        "\n",
        "\n",
        "To be announced!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "szwF_5fZ0rGt"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MXN33Y-bxue5"
      },
      "source": [
        "### Robust Generalization\n",
        "\n",
        "Here we demonstrate a *covariate shift* problem by adding ImageNet-C-style Gaussian noise ([Hendrycks \u0026 Gimpel, 2019](http://arxiv.org/abs/1903.12261)) to an input image and showing the model's predictions as the noise increases. In this type of problem, we view shifted examples as \"noisy\", but close enough to the distribution of training examples that we desire our model to be robust to the noise and still make strong predictions. Corruption levels 1-5 correspond to those in ImageNet-C, and we add additional levels above those. We see that Plex models are able to make confident predictions under large amounts of noise. Full evaluation results are in the paper."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y2teWA8dxt61"
      },
      "outputs": [],
      "source": [
        "# Define a Gaussin noise function to form ImageNet-C-style Gaussian noise\n",
        "# corruptions.\n",
        "def gaussian_noise(x, severity, rng):\n",
        "  severity_scales = [.08, .12, 0.18, 0.26, 0.38, 0.6, 1.]\n",
        "  assert severity in range(1, len(severity_scales) + 1)\n",
        "  scale = severity_scales[severity - 1]\n",
        "  x = x / 255.\n",
        "  x = jnp.clip(x + scale * jax.random.normal(rng, shape=x.shape), 0, 1) * 255\n",
        "  return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kkkdvlUbvpPb"
      },
      "outputs": [],
      "source": [
        "# Load and preprocess image.\n",
        "def preprocess_fn(image, severity=None, rng=None):\n",
        "  # Note: The model was trained with this preprocessing.\n",
        "  x = tf.convert_to_tensor(image)\n",
        "  x = tf.io.decode_image(x, channels=3, expand_animations=False)\n",
        "  x = tf.cast(tf.image.resize(x, (384, 384)), tf.float32)\n",
        "  x = jnp.asarray(x)\n",
        "  if severity is not None:\n",
        "    x = gaussian_noise(x, severity, rng)\n",
        "  x = x / 255. * 2 - 1\n",
        "  return x\n",
        "\n",
        "with open(\"image.jpg\", mode='rb') as f:\n",
        "  image_data = f.read()\n",
        "\n",
        "image = preprocess_fn(image_data)\n",
        "corrupted_images = [preprocess_fn(image_data, s, jax.random.fold_in(rng, 0))\n",
        "                    for s in range(1, 8)]\n",
        "images = jnp.array([image] + corrupted_images)\n",
        "images.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pjIx0wpavpPc"
      },
      "outputs": [],
      "source": [
        "# Make predictions.\n",
        "rng_eval = jax.random.fold_in(rng, 0)\n",
        "probs = predict_fn(params, images, rng_eval)\n",
        "probs.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xgwkflBbvpPc"
      },
      "outputs": [],
      "source": [
        "# Output top 5 predictions.\n",
        "all_top_preds = tf.keras.applications.imagenet_utils.decode_predictions(\n",
        "    probs, top=5)\n",
        "\n",
        "# Only show 10 images.\n",
        "for i, (image, top_preds) in enumerate(zip(images, all_top_preds)):\n",
        "  plt.figure(figsize=(4, 4))\n",
        "  plt.imshow(image * .5 + .5)\n",
        "  plt.axis('off')\n",
        "  plt.show()\n",
        "\n",
        "  if i \u003e 0:\n",
        "    print(f\"Corruption level: {i}\")\n",
        "  for _, pred_class_name, prob in top_preds:\n",
        "    print(f\"{float(prob):.6f} : {pred_class_name}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WmQBjpdSxwb8"
      },
      "source": [
        "### Adaptation\n",
        "\n",
        "Here we demonstrate zero-shot out-of-distribution (OOD) detection using the upstream pretrained model and the relative Mahalanobis distance metric ([Ren et al., 2021](http://arxiv.org/abs/2106.09022)). In zero-shot OOD detection, the goal is to take a fixed model that was pretrained on dataset A and use it to distinguish between in-distributions samples from dataset B and OOD samples from dataset C, all without training the model further on datset B or C. We see that pretrained Plex without any finetuning is able to achieve a strong separation between in and out of distribution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-iWtPSv7Yh9y"
      },
      "outputs": [],
      "source": [
        "# Free up RAM.\n",
        "del probs, batch, params, checkpoint\n",
        "\n",
        "import gc\n",
        "gc.collect()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hZHOW3CO5l4G"
      },
      "outputs": [],
      "source": [
        "def get_pretrained_config():\n",
        "  # From `https://github.com/google/uncertainty-baselines/blob/main/baselines/jft/experiments/vit_l32_plex_finetune.py`.\n",
        "  # TODO(dusenberrymw): Clean up this config.\n",
        "  config = ml_collections.ConfigDict()\n",
        "  config.model = ml_collections.ConfigDict()\n",
        "  config.model.patches = ml_collections.ConfigDict()\n",
        "  config.model.patches.size = [32, 32]\n",
        "  config.model.hidden_size = 1024\n",
        "  config.model.transformer = ml_collections.ConfigDict()\n",
        "  config.model.transformer.mlp_dim = 4096\n",
        "  config.model.transformer.num_heads = 16\n",
        "  config.model.transformer.num_layers = 24\n",
        "  config.model.transformer.attention_dropout_rate = 0.\n",
        "  config.model.transformer.dropout_rate = 0.\n",
        "  config.model.classifier = 'token'\n",
        "  config.model.representation_size = None\n",
        "\n",
        "  # BatchEnsemble\n",
        "  config.model.transformer.be_layers = (21, 22, 23)\n",
        "  config.model.transformer.ens_size = 3\n",
        "  config.model.transformer.random_sign_init = -0.5\n",
        "  config.model.transformer.ensemble_attention = False\n",
        "\n",
        "  return config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hFYYjLfI5YFe"
      },
      "outputs": [],
      "source": [
        "num_classes = 21843\n",
        "config = get_pretrained_config()\n",
        "pretrained_model = ub.models.vision_transformer_be(\n",
        "    num_classes=num_classes, **config.model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oWU4Y5GHyeOA"
      },
      "outputs": [],
      "source": [
        "@jax.jit\n",
        "def representation_fn(params, images, rng):\n",
        "  rng_dropout, rng_diag_noise, rng_standard_noise = jax.random.split(rng, num=3)\n",
        "  _, out = pretrained_model.apply(\n",
        "      {'params': flax.core.freeze(params)},\n",
        "      images,\n",
        "      train=False,\n",
        "      rngs={\n",
        "          'dropout': rng_dropout,\n",
        "          'diag_noise_samples': rng_diag_noise,\n",
        "          'standard_norm_noise_samples': rng_standard_noise})\n",
        "  representations = out[\"pre_logits\"]\n",
        "  ens_representations = jnp.stack(jnp.split(representations,\n",
        "                                            model.transformer.ens_size), axis=1)\n",
        "  return ens_representations  # Shape (batch_size, ens_sizen, um_classes).\n",
        "\n",
        "# Create a model function that works across multiple TPU devices or across\n",
        "# multiple GPUs for performance. The value for `in_axes` means that the `params`\n",
        "# argument for `predict_fn` will be copied to each device, the `images` will be\n",
        "# split (\"sharded\") across the devices along the first axis, and the `rng` will\n",
        "# be copied to each device. Note that this means that `images` should have shape\n",
        "# `(num_devices, batch_size/num_devices, h, w, c)` so that each device processes\n",
        "# a `(batch_size/num_devices, h, w, c)` chunk of the images. The `params` and\n",
        "# `rng` will be the same as in the \"Singe image\" example up above.\n",
        "pmapped_representation_fn = jax.pmap(representation_fn, in_axes=(None, 0, None))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZpeKRlOukbpL"
      },
      "outputs": [],
      "source": [
        "@functools.partial(jax.jit, backend='cpu')\n",
        "def init(rng):\n",
        "  image_size = (384, 384, 3)  # Note the larger input size.\n",
        "  logging.info('image_size = %s', image_size)\n",
        "  dummy_input = jnp.zeros((1,) + image_size, jnp.float32)\n",
        "  params = flax.core.unfreeze(pretrained_model.init(rng, dummy_input,\n",
        "                                                    train=False))['params']\n",
        "\n",
        "  # Set bias in the head to a low value, such that loss is small initially.\n",
        "  params['batchensemble_head']['bias'] = jnp.full_like(\n",
        "      params['batchensemble_head']['bias'], config.get('init_head_bias', 0))\n",
        "\n",
        "  # init head kernel to all zeros for fine-tuning\n",
        "  if config.get('model_init'):\n",
        "    params['batchensemble_head']['kernel'] = jnp.full_like(\n",
        "        params['batchensemble_head']['kernel'], 0)\n",
        "\n",
        "  return params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Inw6Nvxw6r4h"
      },
      "outputs": [],
      "source": [
        "checkpoint_path = \"gs://plex-paper/plex_vit_large_imagenet21k.npz\"\n",
        "read_in_parallel = False\n",
        "loaded_params = checkpoint_utils.load_checkpoint(\n",
        "    None, path=checkpoint_path, read_in_parallel=read_in_parallel)\n",
        "\n",
        "rng, rng_init = jax.random.split(rng)\n",
        "init_params = init(rng_init)\n",
        "\n",
        "pretrained_params = checkpoint_utils.restore_from_pretrained_params(\n",
        "    init_params, loaded_params, config.model.representation_size,\n",
        "    config.model.classifier, reinit_params=None)\n",
        "\n",
        "del loaded_params, init_params  # Free up memory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3l3FcAH0cutN"
      },
      "outputs": [],
      "source": [
        "def get_and_reshape(x):\n",
        "  # Fetch probs from all devices to CPU and reshape to (batch_size, ...).\n",
        "  return jnp.reshape(jax.device_get(x), (-1,) + x.shape[2:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oubRHRls7D1Q"
      },
      "outputs": [],
      "source": [
        "# https://www.tensorflow.org/datasets/catalog/imagenet_v2\n",
        "dataset = \"imagenet_v2\"\n",
        "tfds.builder(dataset).download_and_prepare()\n",
        "batch_size = 64 * jax.local_device_count()\n",
        "split = \"test\"\n",
        "\n",
        "pp_eval = f\"decode|resize(384)|value_range(-1, 1)|onehot(1000, key='label', key_result='labels')|keep(['image', 'labels', 'id'])\"\n",
        "preprocess_eval_fn = preprocess_spec.parse(\n",
        "    spec=pp_eval, available_ops=preprocess_utils.all_ops())\n",
        "\n",
        "val_ds = load_val_ds(dataset, split=split, batch_size=batch_size,\n",
        "                     preprocess_eval_fn=preprocess_eval_fn)\n",
        "\n",
        "in_dist_representations = []\n",
        "in_dist_labels = []\n",
        "masks = []\n",
        "\n",
        "# NOTE: given more compute, use the entire dataset instead.\n",
        "val_ds = val_ds.shuffle(256, seed=42).take(int(1024 / batch_size))\n",
        "for i, batch in enumerate(val_ds.as_numpy_iterator()):\n",
        "  rng_eval = jax.random.fold_in(rng, 0)\n",
        "  representation = pmapped_representation_fn(pretrained_params, batch[\"image\"],\n",
        "                                             rng_eval)\n",
        "  in_dist_representations.append(get_and_reshape(representation))\n",
        "  masks.append(get_and_reshape(batch[\"mask\"]))\n",
        "  in_dist_labels.append(get_and_reshape(jnp.argmax(batch[\"labels\"], axis=-1)))\n",
        "\n",
        "mask = jnp.concatenate(jax.device_get(masks))\n",
        "in_dist_representations = jnp.concatenate(in_dist_representations)[mask == 1]\n",
        "in_dist_labels = jnp.concatenate(in_dist_labels)[mask == 1]\n",
        "in_dist_representations.shape, in_dist_labels.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JO9ZL1LuXQW8"
      },
      "outputs": [],
      "source": [
        "ens_means, ens_covs = [], []\n",
        "ens_means_background, ens_covs_background = [], []\n",
        "for m in range(in_dist_representations.shape[1]):\n",
        "  means, cov = ood_utils.compute_mean_and_cov(\n",
        "      in_dist_representations[:, m],\n",
        "      in_dist_labels,\n",
        "      class_ids=jnp.unique(in_dist_labels))\n",
        "  ens_means.append(means)\n",
        "  ens_covs.append(cov)\n",
        "\n",
        "  means_bg, cov_bg = ood_utils.compute_mean_and_cov(\n",
        "      in_dist_representations[:, m],\n",
        "      jnp.zeros_like(in_dist_labels),\n",
        "      class_ids=jnp.array([0]))\n",
        "  ens_means_background.append(means_bg)\n",
        "  ens_covs_background.append(cov_bg)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SiSyTb5wjYNL"
      },
      "outputs": [],
      "source": [
        "# Relative Mahalanobis score of the in-distribution examples: for each set of\n",
        "# fitted statistics, take the smallest distance over the last axis of the\n",
        "# class-wise distances and subtract the distance to the background statistics.\n",
        "ens_in_dist_rmaha_distances = []\n",
        "fitted_stats = zip(ens_means, ens_covs, ens_means_background,\n",
        "                   ens_covs_background)\n",
        "for member, (class_means, class_cov, bg_means, bg_cov) in enumerate(\n",
        "    fitted_stats):\n",
        "  features = in_dist_representations[:, member]\n",
        "  class_distances = ood_utils.compute_mahalanobis_distance(\n",
        "      features, class_means, class_cov)\n",
        "  background_distances = ood_utils.compute_mahalanobis_distance(\n",
        "      features, bg_means, bg_cov)\n",
        "  closest_class_distance = jnp.min(class_distances, axis=-1)\n",
        "  ens_in_dist_rmaha_distances.append(\n",
        "      closest_class_distance - background_distances[:, 0])\n",
        "\n",
        "# Average the per-member scores into a single score per example, then drop\n",
        "# the per-member list.\n",
        "in_dist_rmaha_distances = jnp.array(ens_in_dist_rmaha_distances).mean(axis=0)\n",
        "del ens_in_dist_rmaha_distances\n",
        "in_dist_rmaha_distances.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ITX6Iz9rW1Bv"
      },
      "outputs": [],
      "source": [
        "# Load Fashion-MNIST as the out-of-distribution (OOD) dataset and extract its\n",
        "# representations with the pretrained model.\n",
        "# https://www.tensorflow.org/datasets/catalog/fashion_mnist\n",
        "dataset = \"fashion_mnist\"\n",
        "tfds.builder(dataset).download_and_prepare()\n",
        "batch_size = 64 * jax.local_device_count()\n",
        "split = \"test\"\n",
        "\n",
        "# Plain string: there is nothing to interpolate, so no f-string is needed.\n",
        "pp_eval = \"decode|resize(384)|value_range(-1, 1)|keep(['image'])\"\n",
        "preprocess_eval_fn = preprocess_spec.parse(\n",
        "    spec=pp_eval, available_ops=preprocess_utils.all_ops())\n",
        "\n",
        "val_ds = load_val_ds(dataset, split=split, batch_size=batch_size,\n",
        "                     preprocess_eval_fn=preprocess_eval_fn)\n",
        "\n",
        "ood_representations = []\n",
        "masks = []\n",
        "\n",
        "# The same derived key is used for every batch, so compute it once.\n",
        "rng_eval = jax.random.fold_in(rng, 0)\n",
        "\n",
        "# NOTE: given more compute, use the entire dataset instead.\n",
        "val_ds = val_ds.shuffle(256, seed=42).take(1024 // batch_size)\n",
        "for batch in val_ds.as_numpy_iterator():\n",
        "  representation = pmapped_representation_fn(pretrained_params, batch[\"image\"],\n",
        "                                             rng_eval)\n",
        "  ood_representations.append(get_and_reshape(representation))\n",
        "  masks.append(get_and_reshape(batch[\"mask\"]))\n",
        "\n",
        "# Keep only the entries whose mask equals 1 (presumably this drops padding\n",
        "# added to fill the final batch -- confirm against `load_val_ds`).\n",
        "mask = jnp.concatenate(masks)\n",
        "ood_representations = jnp.concatenate(ood_representations)[mask == 1]\n",
        "ood_representations.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zg2Z-WWEXOl_"
      },
      "outputs": [],
      "source": [
        "# Relative Mahalanobis score of the OOD examples, using the statistics fitted\n",
        "# on the in-distribution data: smallest distance over the last axis of the\n",
        "# class-wise distances minus the distance to the background statistics.\n",
        "ens_ood_rmaha_distances = []\n",
        "fitted_stats = zip(ens_means, ens_covs, ens_means_background,\n",
        "                   ens_covs_background)\n",
        "for member, (class_means, class_cov, bg_means, bg_cov) in enumerate(\n",
        "    fitted_stats):\n",
        "  features = ood_representations[:, member]\n",
        "  class_distances = ood_utils.compute_mahalanobis_distance(\n",
        "      features, class_means, class_cov)\n",
        "  background_distances = ood_utils.compute_mahalanobis_distance(\n",
        "      features, bg_means, bg_cov)\n",
        "  closest_class_distance = jnp.min(class_distances, axis=-1)\n",
        "  ens_ood_rmaha_distances.append(\n",
        "      closest_class_distance - background_distances[:, 0])\n",
        "\n",
        "# Average the per-member scores into a single score per example, then drop\n",
        "# the per-member list.\n",
        "ood_rmaha_distances = jnp.array(ens_ood_rmaha_distances).mean(axis=0)\n",
        "del ens_ood_rmaha_distances\n",
        "ood_rmaha_distances.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9b8rDdFRmv7c"
      },
      "outputs": [],
      "source": [
        "# Treat OOD detection as binary classification on the relative Mahalanobis\n",
        "# score: in-distribution examples are the negatives (0) and OOD examples are\n",
        "# the positives (1).\n",
        "scores = jnp.concatenate([in_dist_rmaha_distances, ood_rmaha_distances])\n",
        "in_dist_targets = jnp.zeros_like(in_dist_rmaha_distances)\n",
        "ood_targets = jnp.ones_like(ood_rmaha_distances)\n",
        "labels = jnp.concatenate([in_dist_targets, ood_targets])\n",
        "\n",
        "aucroc = sklearn.metrics.roc_auc_score(labels, scores)\n",
        "print(aucroc)\n",
        "\n",
        "# Overlay the normalized score histograms of the two populations.\n",
        "plt.hist(\n",
        "    [in_dist_rmaha_distances, ood_rmaha_distances],\n",
        "    bins=100,\n",
        "    density=True,\n",
        "    label=[\"in-dist\", \"ood\"],\n",
        ")\n",
        "plt.legend()\n",
        "plt.title(f\"aucroc={aucroc}\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F7n5xqhpupU3"
      },
      "outputs": [],
      "source": [
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FXH-hUpnJIFz"
      },
      "source": [
        "## Extras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A0l6FwKW-91b"
      },
      "source": [
        "### Export to TensorFlow for serving, embedded devices, TF.js, etc.\n",
        "\n",
        "To be announced!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yXoV8tIS_DPc"
      },
      "outputs": [],
      "source": [
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "ViT-Plex Demo",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/plex/plex_vit_demo.ipynb?workspaceId=dusenberrymw:fig-export-ub-1870-change-83::citc",
          "timestamp": 1672347114542
        },
        {
          "file_id": "1SLwhsSGHyBy7vfVDZOfJPh5uDxzRDJvt",
          "timestamp": 1657739775185
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
